{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Classification"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-info\">\n",
    "\n",
    "This tutorial is available as an IPython notebook at [Malaya/example/zeroshot-classification](https://github.com/huseinzol05/Malaya/tree/master/example/zeroshot-classification).\n",
    "    \n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-info\">\n",
    "\n",
    "This module is trained on both standard and local (including social media) language structures, so it is safe to use for both.\n",
    "    \n",
    "</div>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = ''\n",
    "os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/husein/.local/lib/python3.8/site-packages/bitsandbytes/cextension.py:34: UserWarning: The installed version of bitsandbytes was compiled without GPU support. 8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.\n",
      "  warn(\"The installed version of bitsandbytes was compiled without GPU support. \"\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/home/husein/.local/lib/python3.8/site-packages/bitsandbytes/libbitsandbytes_cpu.so: undefined symbol: cadam32bit_grad_fp32\n",
      "CPU times: user 3.27 s, sys: 3.22 s, total: 6.5 s\n",
      "Wall time: 2.64 s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/husein/dev/malaya/malaya/tokenizer.py:214: FutureWarning: Possible nested set at position 3397\n",
      "  self.tok = re.compile(r'({})'.format('|'.join(pipeline)))\n",
      "/home/husein/dev/malaya/malaya/tokenizer.py:214: FutureWarning: Possible nested set at position 3927\n",
      "  self.tok = re.compile(r'({})'.format('|'.join(pipeline)))\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "import malaya"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### What is zero-shot classification\n",
    "\n",
    "Commonly, we train a supervised machine learning model on a fixed set of labels: negative / positive for sentiment, anger / happy / sadness for emotion, and so on. Such a model cannot tell us how much 'jealous' a text expresses, because the only labels it supports are {anger, happy, sadness}. Asking it to recognize a 'jealous' label it has never seen during training is impossible. **Zero-shot classification tries to solve this problem.**\n",
    "\n",
    "Zero-shot learning refers to the process by which a model learns to recognize classes (of images, text, or any other features) without any labeled training data for those classes.\n",
    "\n",
    "[Yin et al. (2019)](https://arxiv.org/abs/1909.00161) showed that a pretrained language model fine-tuned on a sentence-pair task (textual entailment in their paper) can act as an out-of-the-box zero-shot text classifier."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "So, we are going to use transformer models from `malaya.similarity.semantic.huggingface` with a few tweaks."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### List available HuggingFace models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'mesolitica/finetune-mnli-nanot5-small': {'Size (MB)': 148,\n",
       "  'macro precision': 0.87125,\n",
       "  'macro recall': 0.87131,\n",
       "  'macro f1-score': 0.87127},\n",
       " 'mesolitica/finetune-mnli-nanot5-base': {'Size (MB)': 892,\n",
       "  'macro precision': 0.78903,\n",
       "  'macro recall': 0.79064,\n",
       "  'macro f1-score': 0.78918}}"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "malaya.zero_shot.classification.available_huggingface"
   ]
  },
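  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The table above can also be inspected programmatically. As a small sketch, the snippet below copies the printed values into a local variable and picks the entry with the highest macro F1-score. The names `available` and `best_model` are illustrative and not part of Malaya.\n",
    "\n",
    "```python\n",
    "# Copy of the table printed above; in practice it can be read from\n",
    "# malaya.zero_shot.classification.available_huggingface instead.\n",
    "available = {\n",
    "    'mesolitica/finetune-mnli-nanot5-small': {\n",
    "        'Size (MB)': 148,\n",
    "        'macro precision': 0.87125,\n",
    "        'macro recall': 0.87131,\n",
    "        'macro f1-score': 0.87127,\n",
    "    },\n",
    "    'mesolitica/finetune-mnli-nanot5-base': {\n",
    "        'Size (MB)': 892,\n",
    "        'macro precision': 0.78903,\n",
    "        'macro recall': 0.79064,\n",
    "        'macro f1-score': 0.78918,\n",
    "    },\n",
    "}\n",
    "\n",
    "\n",
    "def best_model(models, metric='macro f1-score'):\n",
    "    # Return the model name whose value for `metric` is highest.\n",
    "    return max(models, key=lambda name: models[name][metric])\n",
    "\n",
    "\n",
    "print(best_model(available))\n",
    "```"
   ]
  },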
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load HuggingFace model\n",
    "\n",
    "```python\n",
    "def huggingface(\n",
    "    model: str = 'mesolitica/finetune-mnli-t5-small-standard-bahasa-cased',\n",
    "    force_check: bool = True,\n",
    "    **kwargs,\n",
    "):\n",
    "    \"\"\"\n",
    "    Load HuggingFace model to zeroshot text classification.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    model: str, optional (default='mesolitica/finetune-mnli-t5-small-standard-bahasa-cased')\n",
    "        Check available models at `malaya.zero_shot.classification.available_huggingface()`.\n",
    "    force_check: bool, optional (default=True)\n",
    "        Force check model one of malaya model.\n",
    "        Set to False if you have your own huggingface model.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    result: malaya.torch_model.huggingface.ZeroShotClassification\n",
    "    \"\"\"\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = malaya.zero_shot.classification.huggingface()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Predict batch\n",
    "\n",
    "```python\n",
    "def predict_proba(\n",
    "    self,\n",
    "    strings: List[str],\n",
    "    labels: List[str],\n",
    "    prefix: str = 'ayat ini berkaitan tentang',\n",
    "    multilabel: bool = True,\n",
    "):\n",
    "    \"\"\"\n",
    "    classify list of strings and return probability.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    strings: List[str]\n",
    "    labels: List[str]\n",
    "    prefix: str, optional (default='ayat ini berkaitan tentang')\n",
    "        prefix of labels to zero shot. Playing around with prefix can get better results.\n",
    "    multilabel: bool, optional (default=True)\n",
    "        probability of labels can be more than 1.0\n",
    "```\n",
    "\n",
    "Because this is zero-shot classification, we need to give the model the candidate labels."
   ]
  },
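  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the role of `prefix` concrete, here is a rough sketch of how a zero-shot classifier of this kind can turn every (string, label) combination into a premise and hypothesis pair before scoring it. The helper `build_pairs` is hypothetical, and the exact formatting Malaya uses internally may differ.\n",
    "\n",
    "```python\n",
    "from typing import List, Tuple\n",
    "\n",
    "\n",
    "def build_pairs(\n",
    "    strings: List[str],\n",
    "    labels: List[str],\n",
    "    prefix: str = 'ayat ini berkaitan tentang',\n",
    ") -> List[Tuple[str, str]]:\n",
    "    # One (premise, hypothesis) pair per string and label, in input order.\n",
    "    return [(s, f'{prefix} {label}') for s in strings for label in labels]\n",
    "\n",
    "\n",
    "pairs = build_pairs(['tolong order foodpanda jab, lapar'], ['makan', 'novel'])\n",
    "for premise, hypothesis in pairs:\n",
    "    print(premise, '->', hypothesis)\n",
    "```\n",
    "\n",
    "Changing `prefix` changes every hypothesis at once, which is why experimenting with it can shift the scores."
   ]
  },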
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# copied from Twitter\n",
    "\n",
    "string = 'gov macam bengong, kami nk pilihan raya, gov backdoor, sakai'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "You're using a PreTrainedTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[{'najib razak': 0.089086466,\n",
       "  'mahathir': 0.8503896,\n",
       "  'kerajaan': 0.31621307,\n",
       "  'PRU': 0.5521264,\n",
       "  'anarki': 0.018142236}]"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.predict_proba([string], labels = ['najib razak', 'mahathir', 'kerajaan', 'PRU', 'anarki'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "string = 'tolong order foodpanda jab, lapar'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'makan': 0.9966216,\n",
       "  'makanan': 0.9912846,\n",
       "  'novel': 0.01200958,\n",
       "  'buku': 0.0026836568,\n",
       "  'kerajaan': 0.005800651,\n",
       "  'food delivery': 0.94829154}]"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.predict_proba([string], labels = ['makan', 'makanan', 'novel', 'buku', 'kerajaan', 'food delivery'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The model understood that `order foodpanda` is closely related to `makan`, `makanan` and `food delivery`."
   ]
  },
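  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because `predict_proba` returns one dictionary per input string, picking out the relevant labels is plain dictionary handling. The sketch below reuses the scores printed above. The helper `labels_above` and the 0.5 threshold are illustrative choices, not part of Malaya.\n",
    "\n",
    "```python\n",
    "# Scores copied from the output above.\n",
    "scores = {\n",
    "    'makan': 0.9966216,\n",
    "    'makanan': 0.9912846,\n",
    "    'novel': 0.01200958,\n",
    "    'buku': 0.0026836568,\n",
    "    'kerajaan': 0.005800651,\n",
    "    'food delivery': 0.94829154,\n",
    "}\n",
    "\n",
    "\n",
    "def labels_above(result, threshold=0.5):\n",
    "    # Keep labels scoring at least `threshold`, highest score first.\n",
    "    kept = [(label, p) for label, p in result.items() if p >= threshold]\n",
    "    return sorted(kept, key=lambda item: item[1], reverse=True)\n",
    "\n",
    "\n",
    "print(labels_above(scores))\n",
    "```"
   ]
  },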
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "string = 'kerajaan sebenarnya sangat prihatin dengan rakyat, bagi duit bantuan'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'makan': 0.0023917605,\n",
       "  'makanan': 0.002768525,\n",
       "  'novel': 0.0035945452,\n",
       "  'buku': 0.0028883144,\n",
       "  'kerajaan': 0.9981665,\n",
       "  'food delivery': 0.0029965744,\n",
       "  'kerajaan jahat': 0.95778364,\n",
       "  'kerajaan prihatin': 0.9981933,\n",
       "  'bantuan rakyat': 0.99804246}]"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.predict_proba([string], labels = ['makan', 'makanan', 'novel', 'buku', 'kerajaan', 'food delivery',\n",
    "                                       'kerajaan jahat', 'kerajaan prihatin', 'bantuan rakyat'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Able to infer on mixed MS and EN text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "string = 'Hi guys! I noticed semalam & harini dah ramai yang dapat cookies ni kan. So harini i nak share some post mortem of our first batch:'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'makan': 0.007691883,\n",
       "  'makanan': 0.997271,\n",
       "  'novel': 0.039510652,\n",
       "  'buku': 0.03565315,\n",
       "  'kerajaan': 0.0074525476,\n",
       "  'food delivery': 0.9393526,\n",
       "  'kerajaan jahat': 0.0053522647,\n",
       "  'kerajaan prihatin': 0.011083162,\n",
       "  'bantuan rakyat': 0.060150616,\n",
       "  'biskut': 0.9302781,\n",
       "  'very helpful': 0.07355973,\n",
       "  'sharing experiences': 0.9778896,\n",
       "  'sharing session': 0.014371477}]"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.predict_proba([string], labels = ['makan', 'makanan', 'novel', 'buku', 'kerajaan', 'food delivery',\n",
    "                                       'kerajaan jahat', 'kerajaan prihatin', 'bantuan rakyat',\n",
    "                                       'biskut', 'very helpful', 'sharing experiences',\n",
    "                                       'sharing session'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'makan': 0.0014243807,\n",
       "  'makanan': 0.004838416,\n",
       "  'novel': 0.0019961353,\n",
       "  'buku': 0.003897282,\n",
       "  'kerajaan': 0.004189471,\n",
       "  'food delivery': 0.97480994,\n",
       "  'kerajaan jahat': 0.0018161167,\n",
       "  'kerajaan prihatin': 0.0054033417,\n",
       "  'bantuan rakyat': 0.0054734466,\n",
       "  'biskut': 0.018219633,\n",
       "  'very helpful': 0.03659028,\n",
       "  'sharing experiences': 0.98463523,\n",
       "  'sharing session': 0.013350475}]"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.predict_proba([string], labels = ['makan', 'makanan', 'novel', 'buku', 'kerajaan', 'food delivery',\n",
    "                                       'kerajaan jahat', 'kerajaan prihatin', 'bantuan rakyat',\n",
    "                                       'biskut', 'very helpful', 'sharing experiences',\n",
    "                                       'sharing session'],\n",
    "                   prefix = 'teks ini berkaitan tentang')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Multiclass but not multilabel\n",
    "\n",
    "To make the probabilities sum to 1.0 across all labels, set `multilabel=False`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'makan': 0.00062935066,\n",
       "  'makanan': 0.00067746383,\n",
       "  'novel': 0.0007715335,\n",
       "  'buku': 0.0006922778,\n",
       "  'kerajaan': 0.2833456,\n",
       "  'food delivery': 0.0007045073,\n",
       "  'kerajaan jahat': 0.05875754,\n",
       "  'kerajaan prihatin': 0.28552753,\n",
       "  'bantuan rakyat': 0.27457199,\n",
       "  'biskut': 0.0007160352,\n",
       "  'very helpful': 0.09099287,\n",
       "  'sharing experiences': 0.0012673552,\n",
       "  'sharing session': 0.0013456849}]"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "string = 'kerajaan sebenarnya sangat prihatin dengan rakyat, bagi duit bantuan'\n",
    "\n",
    "model.predict_proba([string], labels = ['makan', 'makanan', 'novel', 'buku', 'kerajaan', 'food delivery',\n",
    "                                       'kerajaan jahat', 'kerajaan prihatin', 'bantuan rakyat',\n",
    "                                       'biskut', 'very helpful', 'sharing experiences',\n",
    "                                       'sharing session'], multilabel = False)"
   ]
  },
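  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The difference between the two modes can be pictured with a toy calculation. The sketch below is an assumption about the general technique, not a description of Malaya's internals: in the multilabel case each label is scored on its own, so the values need not sum to 1.0, while in the single-label case the per-label scores are passed through a softmax so that they do. The raw scores are made-up numbers.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def sigmoid(x):\n",
    "    # Squash one raw score into (0, 1) independently of the others.\n",
    "    return 1.0 / (1.0 + math.exp(-x))\n",
    "\n",
    "\n",
    "def softmax(xs):\n",
    "    # Subtract the maximum for numerical stability.\n",
    "    m = max(xs)\n",
    "    exps = [math.exp(x - m) for x in xs]\n",
    "    total = sum(exps)\n",
    "    return [e / total for e in exps]\n",
    "\n",
    "\n",
    "# Hypothetical raw scores for three labels.\n",
    "raw = {'kerajaan': 4.0, 'bantuan rakyat': 3.5, 'novel': -3.0}\n",
    "\n",
    "independent = {k: sigmoid(v) for k, v in raw.items()}\n",
    "normalised = dict(zip(raw, softmax(list(raw.values()))))\n",
    "\n",
    "print(sum(independent.values()))  # can exceed 1.0\n",
    "print(sum(normalised.values()))   # 1.0 up to floating-point rounding\n",
    "```"
   ]
  },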
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Stacking models\n",
    "\n",
    "For more information, see https://malaya.readthedocs.io/en/latest/Stack.html\n",
    "\n",
    "If you want to stack zero-shot classification models, you need to pass the labels as a keyword argument:\n",
    "\n",
    "```python\n",
    "malaya.stack.predict_stack([model1, model2], List[str], labels = List[str])\n",
    "```\n",
    "\n",
    "`labels` is then passed through as `**kwargs`."
   ]
  },
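  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough picture of what stacking does with the per-label probabilities, the sketch below combines several prediction dictionaries with a geometric mean. Both the helper `combine` and the choice of geometric mean are assumptions made for illustration; check the stacking documentation linked above for the aggregation Malaya actually applies.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def combine(predictions):\n",
    "    # predictions: one {label: probability} dict per model, same labels in each.\n",
    "    n = len(predictions)\n",
    "    return {\n",
    "        label: math.prod(p[label] for p in predictions) ** (1.0 / n)\n",
    "        for label in predictions[0]\n",
    "    }\n",
    "\n",
    "\n",
    "# Made-up outputs from two hypothetical models.\n",
    "model_a = {'kerajaan': 0.9, 'novel': 0.1}\n",
    "model_b = {'kerajaan': 0.4, 'novel': 0.1}\n",
    "print(combine([model_a, model_b]))\n",
    "```\n",
    "\n",
    "Under this scheme, stacking identical predictions returns those predictions unchanged."
   ]
  },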
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'makan': 0.0023917593,\n",
       "  'makanan': 0.002768525,\n",
       "  'novel': 0.0035945452,\n",
       "  'buku': 0.0028883128,\n",
       "  'kerajaan': 0.9981665,\n",
       "  'food delivery': 0.0029965725,\n",
       "  'kerajaan jahat': 0.95778376,\n",
       "  'kerajaan prihatin': 0.9981934,\n",
       "  'bantuan rakyat': 0.9980425,\n",
       "  'comel': 0.0031943405,\n",
       "  'kerajaan syg sgt kepada rakyat': 0.99586475}]"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "string = 'kerajaan sebenarnya sangat prihatin dengan rakyat, bagi duit bantuan'\n",
    "labels = ['makan', 'makanan', 'novel', 'buku', 'kerajaan', 'food delivery', \n",
    " 'kerajaan jahat', 'kerajaan prihatin', 'bantuan rakyat', 'comel', 'kerajaan syg sgt kepada rakyat']\n",
    "malaya.stack.predict_stack([model, model, model], [string], \n",
    "                           labels = labels)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
